Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch/jax Dataloader support #55

Merged
merged 17 commits into from
Apr 5, 2024
Merged

torch/jax Dataloader support #55

merged 17 commits into from
Apr 5, 2024

Conversation

shenoynikhil
Copy link
Collaborator

@shenoynikhil shenoynikhil commented Mar 24, 2024

Currently, getting a torch_geometric Dataloader is quite complicated since it requires multiple steps in the process. This PR addresses those concerns by introducing,

  1. A method to get torch tensors from getitem call similar to with_format of huggingface datasets Link. This now allows us to get numpy (default), torch or jax arrays from __getitem__ call.

For example,

from openqdc import Dummy

import numpy as np; ds = Dummy(); print (isinstance(ds[0]['positions'], np.ndarray)) # prints True
import torch; ds = Dummy(array_format="torch"); print (isinstance(ds[0]['positions'], torch.Tensor)) # prints True
import jax; ds = Dummy(array_format="jax"); print (isinstance(ds[0]['positions'], jax.numpy.ndarray)) # prints True
  1. Added an option to add a transform. In cases, where we want to use the torch_geometric Data object, instead of the sklearn Bunch from getitem, it might be convenient to use a function on the data bunch returned.
from torch_geometric.data import Data
def custom_transform(bunch): 
    return Data(z=bunch.atomic_numbers, pos=bunch.positions, e=bunch.energies, f=bunch.forces)
ds = Dummy(array_format="torch", transform=custom_transform)
print (isinstance(ds[0], Data)) # prints True
  1. Last but not the least, in case you want to create a pytorch geometric dataloader, you can do so like this. I still haven't created a method explicitly to do so because the amount of effort now would be reduced quite a bit. If you still think it's valuable for the user to get a dataloader, I can implement it.

Note: I looked through huggingface datasets, they too do not have any such method to get a dataloader from their datasets.

from torch_geometric.data import Data, DataLoader
def custom_transform(bunch): return Data(z=bunch.atomic_numbers, pos=bunch.positions, e=bunch.energies, f=bunch.forces)
ds = Dummy(array_format="torch", transform=custom_transform)
dl = DataLoader(ds, batch_size=4)
batch = next(iter(dl))

TODOs

  • Currently I have only added tests for the array_format functionality. Will also add a test for the transform.

@prtos @FNTwin

Checklist:

  • Was this PR discussed in a issue? It is recommended to first discuss a new feature into a GitHub issue before opening a PR.
  • Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate).
  • Update the API documentation is a new function is added or an existing one is deleted.

@shenoynikhil shenoynikhil changed the base branch from main to develop March 24, 2024 07:21
@shenoynikhil shenoynikhil linked an issue Mar 24, 2024 that may be closed by this pull request
@shenoynikhil shenoynikhil changed the title [WIP] torch/jax Dataloader support torch/jax Dataloader support Mar 25, 2024
@shenoynikhil shenoynikhil self-assigned this Mar 25, 2024
Copy link
Collaborator

@FNTwin FNTwin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add a simple as_dataloader function to provide out of the box utilities at this point.
Nevertheless it is a good PR, I could nitpick/disagree a bit on the implementation as I wanted to do some dynamic inheritance based on the available packages to automatically define the return object type but simpler is better down the line

openqdc/datasets/base.py Outdated Show resolved Hide resolved
openqdc/datasets/base.py Outdated Show resolved Hide resolved
openqdc/datasets/base.py Outdated Show resolved Hide resolved
openqdc/datasets/base.py Outdated Show resolved Hide resolved
openqdc/datasets/base.py Show resolved Hide resolved
@shenoynikhil shenoynikhil changed the base branch from develop to release April 3, 2024 17:44
Copy link
Collaborator

@FNTwin FNTwin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All good, with the conversion to tensor getting a dataset into a dataloader should be a one line (for real this time), but I still think that we should have a dummy as_iter method to return a default dataloader.
In any case thank you for the bug fixing and the work! I'll probably try to clean up the conditional import of torch and jax on my end somehow but the PR is 🔥

@shenoynikhil shenoynikhil merged commit ac299a5 into release Apr 5, 2024
5 checks passed
@shenoynikhil shenoynikhil deleted the dataloader branch April 5, 2024 00:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Getting Dataloaders Easily
2 participants